#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Runs training and/or evaluation for the single task stored under
# <workdir>/<task_num>/task.

from __future__ import absolute_import
from __future__ import division

import os

import tensorflow as tf
from google.protobuf import text_format
from compare_gan.src import gan_lib
from compare_gan.src import eval_gan_lib
from compare_gan.src import simple_task_pb2
from compare_gan.src import task_utils

from compare_gan.src import fid_score

# Short module-level aliases for the TensorFlow flag and logging helpers.
flags = tf.flags
FLAGS = flags.FLAGS
logging = tf.logging

# Command-line flags. The task directory is <workdir>/<task_num>.
flags.DEFINE_string(
    "workdir",
    "../../tasks/",
    "The working directory for all tasks.")
flags.DEFINE_integer(
    "task_num",
    1,
    "The task number to use.")
flags.DEFINE_string(
    "command",
    "all",
    "What command to execute: all, train, eval.")
# NOTE(review): this flag is not read anywhere in the visible code.
flags.DEFINE_bool(
    "enable_tf_profile",
    False,
    "Whether to run TFProf.")

def main(_):
  """Runs training and/or evaluation for the task selected by the flags.

  The task definition is read from the text-format protobuf file named "task"
  inside <FLAGS.workdir>/<FLAGS.task_num>. FLAGS.command selects the phases:
  "all" runs both, "train" runs only training and "eval" runs only
  evaluation. Any other value runs neither phase.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  task_workdir = os.path.join(FLAGS.workdir, str(FLAGS.task_num))
  command = FLAGS.command

  # Parse the text-format task description into a Task message.
  task_path = os.path.join(task_workdir, "task")
  task = simple_task_pb2.Task()
  with open(task_path, "r") as task_file:
    text_format.Parse(task_file.read(), task)

  # Derive the options from the task and record the task directory in them.
  options = task_utils.ParseOptions(task)
  options["___workdir"] = task_workdir

  logging.info("\nWill run task\n%s\n\n", text_format.MessageToString(task))

  run_training = command in ("all", "train")
  run_evaluation = command in ("all", "eval")

  if run_training:
    logging.info("\n====== Running training ======\n\n")
    gan_lib.run_with_options(options, task_workdir)

  if run_evaluation:
    logging.info("\n====== Running evaluation ======\n\n")
    # The evaluation is handed a frozen Inception graph loaded from the
    # dataset root directory.
    inception_path = os.path.join(
        FLAGS.dataset_root, "inceptionv1_for_inception_score.pb")
    inception_graph = tf.contrib.gan.eval.get_graph_def_from_disk(
        inception_path)
    eval_gan_lib.RunTaskEval(options, task_workdir, inception_graph)

  logging.info("\n====== Task finished ======\n\n")


if __name__ == "__main__":
  # GPU ids this process is allowed to see. The value is written
  # unconditionally, so it overrides any CUDA_VISIBLE_DEVICES already present
  # in the environment.
  gpu_list = [3]
  # Join with "," as a separator so the value is "3" (or "0,1"), not a list
  # with a trailing comma such as "3,".
  os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu) for gpu in gpu_list)

  tf.app.run()
